# -*- coding: utf-8 -*-
"""
infer_piston_view_v5_2_plus.py
----------------------------------------
AI 模型自動工程圖推論（v5.2 Plus 版）
✅ 對應 piston_v5_2_best.pth
✅ 可由 PHP 呼叫: python infer_piston_view_v5_2_plus.py "@<tmpJson>"
✅ 產生 outputs/piston_ai_view_YYYYMMDD_HHMMSS.png
✅ 自動輸出 __IMG__= 與 SSIM/PSNR 結果
"""

import os, sys, json, datetime
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import cv2
from PIL import Image
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr

# ---------- 設定 ----------
# Project root on the deployment machine; model and output paths hang off it.
ROOT = Path(r"C:\xampp\htdocs\cs_ai")
MODEL_PATH = ROOT.joinpath(r"ai_models\piston_v5_2_best.pth")
OUT_DIR = ROOT.joinpath("outputs")
# Created at import time so main() can save results without further checks.
OUT_DIR.mkdir(exist_ok=True, parents=True)

# Width/height in pixels of the synthetic input canvas (and thus of the saved output).
W = 960
H = 544
# Input channel count passed to UNetAE (single-channel grayscale).
CHANNELS = 1

# ---------- 模型定義 ----------
def conv_block(in_ch, out_ch):
    """Return two stacked stages of Conv2d(3x3, no bias) -> BatchNorm2d -> LeakyReLU(0.1).

    ``padding=1`` keeps height/width unchanged; the first conv maps
    ``in_ch`` -> ``out_ch`` channels and the second keeps ``out_ch``.
    Sequential indices are 0..5 in conv/bn/act order (these are the
    state_dict key prefixes).
    """
    stages = []
    for stage_in in (in_ch, out_ch):
        stages.extend((
            nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.1, inplace=True),
        ))
    return nn.Sequential(*stages)

class UNetAE(nn.Module):
    """Four-level U-Net style encoder/decoder returning a 1-channel map in (0, 1).

    Attribute names (c1..c4, p1..p4, bn, up1..up4, d1..d4, out) and their
    registration order must stay as they are: they are the state_dict keys
    loaded from the saved checkpoint.

    Input height and width should be multiples of 16 (four 2x poolings);
    otherwise the upsampled tensors and skip connections differ in size and
    ``torch.cat`` fails.
    """

    def __init__(self, ch=1, base=32):
        super().__init__()
        # Encoder: each level widens the features and halves the resolution.
        self.c1 = conv_block(ch, base)
        self.p1 = nn.MaxPool2d(2)
        self.c2 = conv_block(base, base * 2)
        self.p2 = nn.MaxPool2d(2)
        self.c3 = conv_block(base * 2, base * 4)
        self.p3 = nn.MaxPool2d(2)
        self.c4 = conv_block(base * 4, base * 8)
        self.p4 = nn.MaxPool2d(2)
        # Bottleneck at 1/16 resolution.
        self.bn = conv_block(base * 8, base * 16)
        # Decoder: 2x transpose-conv upsample, concat with the skip, then fuse.
        self.up4 = nn.ConvTranspose2d(base * 16, base * 8, 2, 2)
        self.d4 = conv_block(base * 16, base * 8)
        self.up3 = nn.ConvTranspose2d(base * 8, base * 4, 2, 2)
        self.d3 = conv_block(base * 8, base * 4)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, 2)
        self.d2 = conv_block(base * 4, base * 2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, 2)
        self.d1 = conv_block(base * 2, base)
        # 1x1 projection to a single output channel.
        self.out = nn.Conv2d(base, 1, 1)

    def forward(self, x):
        skips = []
        feat = x
        encoder = ((self.c1, self.p1), (self.c2, self.p2),
                   (self.c3, self.p3), (self.c4, self.p4))
        for block, pool in encoder:
            feat = block(feat)
            skips.append(feat)          # kept at full level resolution for the decoder
            feat = pool(feat)

        feat = self.bn(feat)

        decoder = zip((self.up4, self.up3, self.up2, self.up1),
                      (self.d4, self.d3, self.d2, self.d1),
                      reversed(skips))
        for upsample, fuse, skip in decoder:
            # Upsampled features first, skip second (channel order matters for the weights).
            feat = fuse(torch.cat([upsample(feat), skip], 1))

        return torch.sigmoid(self.out(feat))

# ---------- JSON 載入 ----------
def load_json_from_arg():
    """Read the job parameters from the JSON file given as ``@<path>`` in argv[1].

    The calling PHP script passes the temp-file path prefixed with ``@``.
    The file is decoded with ``utf-8-sig``, which also strips a leading
    byte-order mark if the writer emitted one (with plain ``utf-8`` the BOM
    would make ``json.load`` reject the file).

    Returns:
        dict: the decoded parameter object.

    Exits the process with status 1, after printing a message to stdout, when
    the argument is missing, the file cannot be read or decoded, or its
    top-level JSON value is not an object.
    """
    if len(sys.argv) > 1 and sys.argv[1].startswith('@'):
        path = sys.argv[1][1:]
    else:
        print("❌ 缺少 JSON 輸入參數。")
        sys.exit(1)

    try:
        with open(path, "r", encoding="utf-8-sig") as f:
            data = json.load(f)
    except OSError as e:
        print(f"❌ 無法讀取 JSON 檔案：{path}（{e}）")
        sys.exit(1)
    except ValueError as e:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError subclasses).
        print(f"❌ JSON 格式錯誤：{path}（{e}）")
        sys.exit(1)

    # Callers use params.get(...), so any other top-level JSON type is unusable.
    if not isinstance(data, dict):
        print(f"❌ JSON 內容必須是物件：{path}")
        sys.exit(1)
    return data

# ---------- 模擬輸入影像 ----------
def json_to_tensor(params, W=960, H=544):
    """Render a synthetic grayscale piston sketch from *params* as model input.

    On a light-gray (0.85) float32 canvas of H x W pixels it draws:
      * a rectangle ``diameter`` wide and ``length`` tall, centred on the canvas
        (value 0.4, 2 px thick),
      * a circle of radius ``diameter / 2`` at the centre (value 0.5, 1 px),
      * ``ring_count`` horizontal lines ``bore`` wide (value 0.3, 1 px), the
        first about 15 px below the rectangle's top edge, spaced 10 px apart,
    then applies a 5x5 Gaussian blur.

    Numeric fields may be numbers or numeric strings (JSON produced by a web
    front end often carries strings); missing, ``None`` or blank values fall
    back to the defaults diameter=60, length=70, ring_count=3, bore=80.

    Args:
        params: dict of drawing parameters.
        W: canvas width in pixels.
        H: canvas height in pixels.

    Returns:
        torch.FloatTensor of shape (1, 1, H, W).

    Raises:
        ValueError: if a field is a non-numeric string.
    """
    def _num(key, default):
        # Coerce to float so string-valued fields work in the arithmetic below.
        value = params.get(key, default)
        if value is None or (isinstance(value, str) and not value.strip()):
            value = default
        return float(value)

    img = np.ones((H, W), np.float32) * 0.85
    cx, cy = W // 2, H // 2
    d = _num("diameter", 60)
    h = _num("length", 70)
    ring = int(_num("ring_count", 3))
    bore = _num("bore", 80)

    # Body outline and centre circle.
    cv2.rectangle(img, (cx - int(d / 2), cy - int(h / 2)), (cx + int(d / 2), cy + int(h / 2)), 0.4, 2)
    cv2.circle(img, (cx, cy), int(d / 2), 0.5, 1)
    # One horizontal line per ring.
    for i in range(ring):
        y = int(cy - h / 2 + 15 + i * 10)
        cv2.line(img, (cx - int(bore / 2), y), (cx + int(bore / 2), y), 0.3, 1)
    img = cv2.GaussianBlur(img, (5, 5), 0)
    return torch.from_numpy(img[None, None, :, :]).float()

# ---------- 主函式 ----------
def _force_utf8_stdout():
    """Switch stdout to UTF-8 so the emoji/CJK status lines cannot crash the script.

    When stdout is a pipe (e.g. captured by the PHP caller), Python on Windows
    may encode it with the ANSI code page, which lacks characters such as
    "✅"; printing them would then raise UnicodeEncodeError.
    """
    reconfigure = getattr(sys.stdout, "reconfigure", None)
    if reconfigure is not None:
        reconfigure(encoding="utf-8")


def _load_model(device):
    """Build UNetAE and load MODEL_PATH weights onto *device*; exit 1 if the file is missing."""
    if not MODEL_PATH.exists():
        print(f"❌ 找不到模型檔案：{MODEL_PATH}")
        sys.exit(1)
    model = UNetAE(ch=CHANNELS, base=32).to(device)
    # The file is passed straight to load_state_dict, i.e. it must be a bare state_dict.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model


def _infer(model, params, device):
    """Render the input from *params*, run *model*, and return an (H, W) uint8 image."""
    x = json_to_tensor(params, W, H).to(device)
    # Mixed precision only on CUDA; cast back to float32 before leaving the device.
    with torch.no_grad(), torch.amp.autocast("cuda", enabled=(device.type == "cuda")):
        y = model(x).float()
    img = y[0, 0].cpu().numpy()
    # Sigmoid output lies in (0, 1); scale to 8-bit, clipping as a guard.
    return np.clip(img * 255, 0, 255).astype(np.uint8)


def _save_png(img8):
    """Save *img8* as OUT_DIR/piston_ai_view_YYYYMMDD_HHMMSS.png and return the path."""
    now = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = OUT_DIR / f"piston_ai_view_{now}.png"
    Image.fromarray(img8).save(out_path)
    return out_path


def _report_quality(gt_path, img8):
    """Print SSIM/PSNR of *img8* against the grayscale image at *gt_path*, if usable.

    Silently skips when no path is given or it is not an existing file; prints
    a warning and skips when the file cannot be decoded as an image.
    """
    if not isinstance(gt_path, str) or not gt_path or not os.path.isfile(gt_path):
        return
    # Decode from bytes instead of cv2.imread: imread is known to fail on
    # non-ASCII paths on Windows, and both return None for undecodable data.
    raw = np.fromfile(gt_path, dtype=np.uint8)
    gt = cv2.imdecode(raw, cv2.IMREAD_GRAYSCALE) if raw.size else None
    if gt is None:
        print(f"⚠️ 無法讀取比對圖：{gt_path}")
        return
    gt = cv2.resize(gt, (W, H))  # dsize is (width, height)
    s = ssim(gt, img8, data_range=255)
    p = psnr(gt, img8, data_range=255)
    print(f"📊 SSIM={s:.4f}, PSNR={p:.2f} dB")


def main():
    """Command-line entry point: JSON params -> model inference -> PNG on disk.

    1. Load the parameter dict from the ``@<json>`` argument.
    2. Load UNetAE weights from MODEL_PATH (exit 1 if missing).
    3. Render the synthetic input with json_to_tensor and run the model.
    4. Save the result under OUT_DIR and print ``__IMG__=<path>`` for the
       calling process to parse.
    5. If ``params["gt_image_path"]`` names a readable image, print SSIM/PSNR
       of the result against it.
    """
    _force_utf8_stdout()
    params = load_json_from_arg()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"✅ 使用裝置：{device}")

    model = _load_model(device)
    img8 = _infer(model, params, device)

    out_path = _save_png(img8)
    print(f"__IMG__={out_path}")
    print(f"✅ 已輸出推論結果：{out_path}")

    # Optional quality comparison against a ground-truth drawing.
    _report_quality(params.get("gt_image_path"), img8)

# Script entry point: runs main() only when executed directly (e.g. invoked by PHP), not on import.
if __name__ == "__main__":
    main()
